package com.stars.easyms.base.asynchronous;

import com.stars.easyms.base.util.DefaultThreadFactory;
import com.stars.easyms.base.util.EasyMsThreadPoolExecutor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.lang.NonNull;

import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.LongAdder;
import java.util.concurrent.locks.*;
import java.util.function.BooleanSupplier;
import java.util.stream.IntStream;

/**
 * <p>className: AbstractAsynchronousTask</p>
 * <p>description: 异步任务处理抽象类</p>
 *
 * @author guoguifang
 * @date 2019-08-22 13:54
 * @since 1.3.0
 */
public abstract class AbstractAsynchronousTask<T> implements AsynchronousTask<T> {

    /**
     * Default number of retries after a task fails.
     */
    private static final int DEFAULT_FAIL_RETRY_COUNT = 3;

    /**
     * Default maximum number of worker threads.
     */
    private static final int DEFAULT_MAX_WORKER_COUNT = 5;

    /**
     * Default per-worker threshold: one more worker thread is added each time the number of
     * pending tasks exceeds another multiple of this value.
     */
    private static final int DEFAULT_SINGLE_MAX_WORKER_COUNT = 100;

    /**
     * Default wait time, in milliseconds.
     */
    private static final long DEFAULT_WAIT_TIME = 60000;

    /**
     * Key under which the failed task object is stored in an error-data entry.
     */
    private static final String ERROR_DATA_KEY_OBJECT = "object";

    /**
     * Key under which the failure timestamp is stored in an error-data entry.
     */
    private static final String ERROR_DATA_KEY_TIME = "time";

    /**
     * Placeholder object that marks a queue slot as "not written yet".
     */
    private static final Object FILL_OBJECT = new Object();

    /**
     * Number of slots in a single queue segment.
     */
    private static final int SINGLE_QUEUE_CAPACITY = 8192;

    /**
     * Template segment that new segments are copied from.
     */
    private static final Object[] BASE_QUEUE = new Object[SINGLE_QUEUE_CAPACITY];

    // Fill every slot of the template segment with the placeholder object.
    static {
        Arrays.fill(BASE_QUEUE, FILL_OBJECT);
    }

    /**
     * Whether this task works in batch mode.
     */
    boolean batch;

    /**
     * Maximum number of tasks handled per execution; non-batch by default, so the default value is 1.
     */
    int maxExecuteCountPerTime = 1;

    /**
     * Whether the lazy initialisation performed by the first add() call has completed.
     */
    private boolean inited;

    /**
     * Logger; created during the first add() call from the concrete class.
     */
    protected Logger logger;

    /**
     * Segmented task storage: maps a segment index to an array of {@code SINGLE_QUEUE_CAPACITY} slots.
     * <p>
     * Declared {@code volatile} because {@code add} tests this field for {@code null} outside of
     * {@code mainLock} (double-checked lazy initialisation). The field is assigned after the other
     * lazily created state, so the volatile write/read pair guarantees that a thread which observes
     * a non-null value also observes that fully initialised state.
     */
    private volatile Map<Integer, Object[]> queue;

    /**
     * Maximum number of pending tasks allowed. The capacity() hook returns Integer.MAX_VALUE by
     * default, and add() never lets it drop below SINGLE_QUEUE_CAPACITY.
     */
    private int capacity;

    /**
     * Pool of worker indexes; a slot holds -1 while its index is borrowed by a worker.
     */
    private int[] workerIndexArray;

    /**
     * Holds the Worker that is running on the current thread.
     */
    private final ThreadLocal<Worker> workerThreadLocal = new ThreadLocal<>();

    /**
     * Index of the most recently claimed write position (-1 before the first put).
     */
    private final AtomicLong putIndex = new AtomicLong(-1);

    /**
     * Index of the most recently claimed read position (-1 before the first take).
     */
    private final AtomicLong takeIndex = new AtomicLong(-1);

    /**
     * Size of the lock-free took-index cache.
     */
    private static final int TOOK_INDEX_CACHE_SIZE = 2048;

    /**
     * Lock-free cache of take indexes that were claimed before their slot had been written;
     * the core worker retries them later. A slot holds -1 when empty.
     */
    private long[] tookIndexCache;

    /**
     * Running counter used to pick the next slot of the lock-free took-index cache.
     */
    private final AtomicLong tookIndexCacheIndex = new AtomicLong();

    /**
     * Synchronized overflow cache of take indexes that were claimed before their slot had been
     * written; used when the chosen lock-free cache slot is occupied.
     */
    private final Set<Long> tookIndexWithLockCache = Collections.synchronizedNavigableSet(new TreeSet<>());

    /**
     * Size of the cache of segment indexes that are waiting to be removed.
     */
    private static final int NEED_REMOVE_QUEUE_INDEX_CACHE_SIZE = 1024;

    /**
     * Indexes of segments that still have to be removed; a slot holds -1 when empty.
     */
    private int[] needRemoveQueueIndexCache;

    /**
     * Guard flag: true when a thread may enter the non-core worker creation step.
     */
    private final AtomicBoolean nonCoreWorkerCreateAccess = new AtomicBoolean(true);

    /**
     * Time (epoch milliseconds) of the last non-core worker count check.
     */
    private long lastNonCoreWorkerCountCheckTime;

    /**
     * Guard flag: true when a thread may enter the core-worker wake-up step.
     */
    private final AtomicBoolean signalAccess = new AtomicBoolean(true);

    /**
     * Guard flag: true when a thread may enter the segment removal step.
     */
    private final AtomicBoolean removeQueueAccess = new AtomicBoolean(true);

    /**
     * Number of tasks currently stored in the queue.
     */
    private final LongAdder count = new LongAdder();

    /**
     * Error data lists, one list per failure count (index = number of failures so far).
     */
    private List[] errorDataListArray;

    /**
     * Number of retries after a failure.
     */
    int failRetryCount;

    /**
     * Number of tasks discarded after exceeding the maximum retry count.
     */
    final LongAdder failedDiscardCount = new LongAdder();

    /**
     * Thread pool that runs the workers.
     */
    private EasyMsThreadPoolExecutor executor;

    /**
     * Current number of workers; starts at 1 for the core worker.
     */
    private AtomicInteger workerCount;

    /**
     * Maximum number of workers.
     */
    private int maxWorkerCount;

    /**
     * One additional non-core worker is created each time the pending count exceeds another
     * multiple of this value.
     */
    private int singleWorkerThreshold;

    /**
     * Lock that guards the lazy initialisation in add() and the updates of needRemoveQueueIndexCache.
     */
    private final Lock mainLock = new ReentrantLock();

    /**
     * Lock the core worker holds while it waits on an empty queue.
     */
    private final Lock notEmptyLock = new ReentrantLock();

    /**
     * Condition the core worker waits on while the queue is empty.
     */
    private final Condition notEmpty = notEmptyLock.newCondition();

    /**
     * Whether the core worker is currently waiting (or about to wait) on an empty queue.
     * <p>
     * {@code volatile}: the core worker writes this flag and producers read it in {@code enqueue}
     * without holding {@code notEmptyLock}. Without the visibility guarantee a producer could keep
     * seeing a stale {@code false}, skip the wake-up and leave the core worker asleep while tasks
     * are pending.
     */
    private volatile boolean emptyWait;

    /**
     * Whether a wake-up signal has already been sent to the waiting core worker.
     * <p>
     * {@code volatile} for the same reason as {@code emptyWait}: it is read by producers outside of
     * {@code notEmptyLock} and reset by the core worker after the lock has been released.
     */
    private volatile boolean notEmptySignal;

    /**
     * Lock producers hold while they wait for free capacity.
     */
    private final Lock notFullLock = new ReentrantLock();

    /**
     * Condition producers wait on while the queue is full.
     */
    private final Condition notFull = notFullLock.newCondition();

    /**
     * Counter of producer threads that are blocked because the queue is full.
     */
    private final LongAdder fullLockThreadCount = new LongAdder();

    /**
     * Adds a task to the queue. The first call lazily creates the queue, the caches, the thread
     * pool and the core worker before the task is enqueued.
     *
     * @param t the task to enqueue; {@code enqueue} rejects {@code null} with a NullPointerException
     */
    @Override
    public void add(T t) {
        if (this.queue == null) {
            this.initOnFirstAdd();
        }
        this.enqueue(t);
    }

    /**
     * One-time initialisation, executed under {@code mainLock}. The null check is repeated inside
     * the lock so that only one thread performs the setup.
     */
    private void initOnFirstAdd() {
        this.mainLock.lock();
        try {
            if (this.queue != null) {
                // Another thread finished the initialisation while this one was waiting for the lock.
                return;
            }
            this.logger = LoggerFactory.getLogger(this.getClass());

            // One synchronized error list per possible failure count.
            this.failRetryCount = Math.max(failRetryCount(), 0);
            this.errorDataListArray = new List[failRetryCount];
            for (int slot = 0; slot < failRetryCount; slot++) {
                this.errorDataListArray[slot] = Collections.synchronizedList(new ArrayList<>());
            }

            this.capacity = Math.max(capacity(), SINGLE_QUEUE_CAPACITY);
            this.singleWorkerThreshold = Math.max(singleWorkerThreshold(), 1);
            this.batchInit();

            // The core worker counts as the first worker; ordered non-batch tasks are limited to it.
            this.workerCount = new AtomicInteger(1);
            this.maxWorkerCount = Math.max(maxWorkerCount(), 1);
            if (isOrder() && !batch) {
                this.maxWorkerCount = 1;
            }
            String threadNameSuffix = this.maxWorkerCount == 1 ? "" : "-%d";
            this.executor = new EasyMsThreadPoolExecutor(1, maxWorkerCount, 0L, TimeUnit.SECONDS,
                    new DefaultThreadFactory().setNameFormat("easyms-asynWorker-" + getName() + threadNameSuffix).build());
            this.workerIndexArray = IntStream.range(0, this.maxWorkerCount).toArray();

            // Both caches use -1 as the "empty slot" marker.
            this.tookIndexCache = new long[TOOK_INDEX_CACHE_SIZE];
            Arrays.fill(this.tookIndexCache, -1);
            this.needRemoveQueueIndexCache = new int[NEED_REMOVE_QUEUE_INDEX_CACHE_SIZE];
            Arrays.fill(this.needRemoveQueueIndexCache, -1);

            this.queue = new ConcurrentHashMap<>(32);
            this.queue.put(0, newArray());
            this.executor.execute(new Worker(true));
            this.inited = true;
        } finally {
            this.mainLock.unlock();
        }
    }

    protected class Worker implements Runnable {

        /** Whether this is the core worker (true) or a non-core worker (false). */
        private final boolean isCore;

        /** Index borrowed from the enclosing task's worker index pool when the worker is created. */
        private final int workerIndex;

        /** Per-worker lock that backs {@link #await} and {@link #signal}. */
        private final Lock lock = new ReentrantLock();

        /** Per-worker condition used by {@link #await} and {@link #signal}. */
        private final Condition condition = lock.newCondition();

        /**
         * Creates a worker and borrows a worker index for it.
         *
         * @param isCore true for the core worker, false for a non-core worker
         */
        private Worker(boolean isCore) {
            this.isCore = isCore;
            this.workerIndex = borrowWorkerIndex();
        }

        /**
         * Worker main loop.
         * <p>
         * The core worker loops forever: it waits while there is nothing to do, drains the queue and
         * then handles due error data. A non-core worker only keeps running while the pending count
         * is above {@code workerIndex * singleWorkerThreshold}.
         * <p>
         * Cleanup (returning the worker index, decrementing the worker count and clearing the
         * thread-local) runs in a {@code finally} block so that an exception escaping a handler does
         * not leak the index, leave {@code workerCount} permanently too high, or leave a stale
         * Worker bound to a pooled thread.
         */
        @Override
        public void run() {
            workerThreadLocal.set(this);
            try {
                // Multiply as long: workerIndex * singleWorkerThreshold may exceed the int range.
                while (this.isCore || count.sum() > (long) workerIndex * singleWorkerThreshold) {
                    if (waitForNoData()) {
                        continue;
                    }
                    while (count.sum() > 0) {
                        handleQueue(isCore);
                    }
                    if (this.isCore && failRetryCount > 0 && !errorDataIsEmpty()) {
                        coreWorkerHandleErrorData();
                    }
                }
            } finally {
                try {
                    giveBackWorkerIndex(this.workerIndex);
                } finally {
                    workerCount.decrementAndGet();
                    workerThreadLocal.remove();
                }
            }
        }

        /**
         * Parks the core worker while there is neither queued data nor error data.
         *
         * @return true when this call went through the wait section (the caller then restarts its
         * loop), false when the worker is not the core worker or there is something to process
         */
        private boolean waitForNoData() {
            boolean nothingToDo = this.isCore && count.sum() == 0 && errorDataIsEmpty();
            if (!nothingToDo) {
                return false;
            }
            notEmptyLock.lock();
            try {
                emptyWait = true;
                for (; ; ) {
                    if (count.sum() != 0 || !errorDataIsEmpty()) {
                        break;
                    }
                    // Clear the flag before every wait so that producers signal again.
                    notEmptySignal = false;
                    notEmpty.awaitUninterruptibly();
                }
            } finally {
                notEmptyLock.unlock();
                emptyWait = false;
                notEmptySignal = false;
            }
            return true;
        }

        /**
         * Checks whether all error data lists are empty.
         *
         * @return true when none of the {@code failRetryCount} error lists holds an entry
         */
        private boolean errorDataIsEmpty() {
            int level = 0;
            // Advance past every empty list; stop at the first one that holds data.
            while (level < failRetryCount && errorDataListArray[level].isEmpty()) {
                level++;
            }
            return level >= failRetryCount;
        }

        /**
         * Retries error data whose waiting period ({@code DEFAULT_WAIT_TIME}) has elapsed, starting
         * with the highest failure count, and then waits until the next entry becomes due or new
         * data arrives.
         * <p>
         * Due entries are first moved out of the shared list while holding the list's monitor, and
         * are handled afterwards without the monitor. Other workers add to these lists concurrently,
         * and iterating a {@code Collections.synchronizedList} without holding its monitor could fail
         * with a ConcurrentModificationException. Handling outside the monitor keeps the handlers
         * from blocking those workers.
         * <p>
         * The wait time is the shortest <em>remaining</em> time until an entry becomes due (not the
         * time that has already elapsed), capped at {@code DEFAULT_WAIT_TIME}.
         */
        @SuppressWarnings("unchecked")
        private void coreWorkerHandleErrorData() {
            long waitTime = DEFAULT_WAIT_TIME;
            for (int failCount = failRetryCount - 1; failCount >= 0; failCount--) {
                List<Map<String, Object>> currentErrorDataList = errorDataListArray[failCount];
                if (currentErrorDataList.isEmpty()) {
                    continue;
                }

                // Collect and remove the due entries under the list's own monitor.
                List<T> dueList = new ArrayList<>();
                synchronized (currentErrorDataList) {
                    Iterator<Map<String, Object>> iterator = currentErrorDataList.iterator();
                    while (iterator.hasNext()) {
                        Map<String, Object> errorData = iterator.next();
                        long failedAt = (Long) errorData.get(ERROR_DATA_KEY_TIME);
                        long elapsed = System.currentTimeMillis() - failedAt;
                        if (elapsed < DEFAULT_WAIT_TIME) {
                            // Not due yet: remember how long until it is.
                            waitTime = Math.min(waitTime, DEFAULT_WAIT_TIME - elapsed);
                            continue;
                        }
                        iterator.remove();
                        dueList.add((T) errorData.get(ERROR_DATA_KEY_OBJECT));
                    }
                }

                if (batch) {
                    List<T> list = new ArrayList<>();
                    for (T t : dueList) {
                        list.add(t);
                        batchHandleErrorData(list, failCount, true);
                    }
                    // Final call without the size check for whatever is left in the list.
                    batchHandleErrorData(list, failCount, false);
                } else {
                    for (T t : dueList) {
                        handleErrorData(t, failCount);
                    }
                }
            }
            waitForHandleErrorData(waitTime);
        }

        /**
         * Lets the core worker sleep for at most {@code waitTime} milliseconds while the queue is
         * empty; a producer signalling {@code notEmpty} ends the wait early.
         * <p>
         * An interrupt does not end the wait: it is remembered and the interrupt status is restored
         * after the wait. Restoring the status inside the loop would make every following
         * {@code awaitNanos} call throw immediately without reducing the remaining time, which
         * turns the loop into a busy spin.
         *
         * @param waitTime maximum time to wait, in milliseconds
         */
        private void waitForHandleErrorData(long waitTime) {
            if (!this.isCore || count.sum() != 0) {
                return;
            }
            boolean interrupted = false;
            notEmptyLock.lock();
            try {
                emptyWait = true;
                long nanos = TimeUnit.MILLISECONDS.toNanos(waitTime);
                final long deadline = System.nanoTime() + nanos;
                while (count.sum() == 0 && nanos > 0) {
                    notEmptySignal = false;
                    try {
                        nanos = notEmpty.awaitNanos(nanos);
                    } catch (InterruptedException e) {
                        interrupted = true;
                        // awaitNanos did not return the remaining time; recompute it from the deadline.
                        nanos = deadline - System.nanoTime();
                    }
                }
            } finally {
                notEmptyLock.unlock();
                emptyWait = false;
                notEmptySignal = false;
                if (interrupted) {
                    Thread.currentThread().interrupt();
                }
            }
        }

        /**
         * Blocks the calling thread on this worker's condition while {@code waitCondition} is true.
         * The condition is re-evaluated after every {@link #signal()} and at least once per second.
         * <p>
         * The method returns only once {@code waitCondition} is false. An interrupt received while
         * waiting is remembered and the interrupt status is restored on exit. Restoring it inside
         * the loop would make every following {@code await} throw immediately and spin the CPU
         * until the condition changes.
         *
         * @param waitCondition returns true while the caller has to keep waiting
         */
        public void await(BooleanSupplier waitCondition) {
            if (!waitCondition.getAsBoolean()) {
                return;
            }
            boolean interrupted = false;
            this.lock.lock();
            try {
                while (waitCondition.getAsBoolean()) {
                    try {
                        this.condition.await(1, TimeUnit.SECONDS);
                    } catch (InterruptedException e) {
                        interrupted = true;
                    }
                }
            } finally {
                this.lock.unlock();
                if (interrupted) {
                    Thread.currentThread().interrupt();
                }
            }
        }

        /**
         * Wakes up one thread that is waiting in {@link #await} on this worker's condition.
         */
        public void signal() {
            final Lock workerLock = this.lock;
            workerLock.lock();
            try {
                // signal() must be called while holding the lock that owns the condition.
                this.condition.signal();
            } finally {
                workerLock.unlock();
            }
        }

        /**
         * Borrows a worker index from the index pool. The original author notes that workers are
         * always created by a single thread at a time, so no lock is taken here.
         *
         * @return the first free index, or {@code maxWorkerCount + 1} when every index is borrowed
         */
        private int borrowWorkerIndex() {
            int slot = 0;
            while (slot < maxWorkerCount) {
                int candidate = workerIndexArray[slot];
                if (candidate != -1) {
                    // Mark the slot as borrowed and hand out the index it held.
                    workerIndexArray[slot] = -1;
                    return candidate;
                }
                slot++;
            }
            return maxWorkerCount + 1;
        }

        /**
         * Returns a worker index to the index pool.
         * <p>
         * {@code borrowWorkerIndex} hands out the out-of-range value {@code maxWorkerCount + 1} when
         * no index is free; such a value has no slot in the pool and is ignored here instead of
         * failing with an ArrayIndexOutOfBoundsException.
         *
         * @param workerIndex the index to return
         */
        private void giveBackWorkerIndex(int workerIndex) {
            if (workerIndex < 0 || workerIndex >= workerIndexArray.length) {
                return;
            }
            workerIndexArray[workerIndex] = workerIndex;
        }
    }

    /**
     * Returns the Worker bound to the calling thread through the thread-local set in Worker.run(),
     * or null when the calling thread is not currently running a worker.
     */
    protected Worker getCurrentWorker() {
        return workerThreadLocal.get();
    }

    /**
     * Number of retries after a failure (3 by default). Read once during the first add() call;
     * negative values are treated as 0.
     */
    protected int failRetryCount() {
        return DEFAULT_FAIL_RETRY_COUNT;
    }

    /**
     * Maximum number of threads that process tasks (5 by default). Read once during the first
     * add() call; values below 1 are treated as 1.
     */
    protected int maxWorkerCount() {
        return DEFAULT_MAX_WORKER_COUNT;
    }

    /**
     * Maximum number of pending tasks (Integer.MAX_VALUE by default). Read once during the first
     * add() call; values below SINGLE_QUEUE_CAPACITY are raised to SINGLE_QUEUE_CAPACITY.
     */
    protected int capacity() {
        return Integer.MAX_VALUE;
    }

    /**
     * A new worker is created each time the pending count exceeds another multiple of this value.
     * Read once during the first add() call; values below 1 are treated as 1.
     */
    protected int singleWorkerThreshold() {
        return DEFAULT_SINGLE_MAX_WORKER_COUNT;
    }

    /**
     * Whether tasks must be executed strictly in order (relevant when maxWorkerCount is greater
     * than 1); false by default. When this returns true and the task is not a batch task, add()
     * limits the worker count to 1.
     */
    public boolean isOrder() {
        return false;
    }

    /**
     * Name of this task, used in the worker thread name; the simple class name by default.
     */
    public String getName() {
        return this.getClass().getSimpleName();
    }

    /**
     * Additional initialisation hook, called once during the first add() call.
     */
    void batchInit() {
        // Intentionally blank
    }

    /**
     * Processes data from the queue. Called repeatedly by a worker while the queue is not empty.
     *
     * @param isCore whether the calling worker is the core worker
     */
    abstract void handleQueue(boolean isCore);

    /**
     * Handles error data in batch mode.
     *
     * @param list        the collected error data
     * @param failCount   the number of failures so far
     * @param isCheckSize whether to check if the size of the list has reached its critical value
     */
    abstract void batchHandleErrorData(List<T> list, int failCount, boolean isCheckSize);

    /**
     * Handles a single error data item.
     *
     * @param t         the error data
     * @param failCount the number of failures so far
     */
    abstract void handleErrorData(T t, int failCount);

    /**
     * Records a failed task together with the current time in the error list of the given failure
     * count. Does nothing when retrying is disabled ({@code failRetryCount} is 0).
     *
     * @param failCount index of the error list to add to
     * @param t         the failed task
     */
    @SuppressWarnings("unchecked")
    void addErrorData(int failCount, T t) {
        if (failRetryCount <= 0) {
            return;
        }
        List<Map<String, Object>> bucket = errorDataListArray[failCount];
        Map<String, Object> entry = new HashMap<>(2);
        entry.put(ERROR_DATA_KEY_OBJECT, t);
        entry.put(ERROR_DATA_KEY_TIME, System.currentTimeMillis());
        bucket.add(entry);
    }

    /**
     * Stores a task in the next free slot, blocking while the queue is at capacity, then wakes the
     * core worker if it is waiting and checks whether additional workers are needed.
     * <p>
     * {@code fullLockThreadCount} is incremented before a producer starts waiting for capacity and
     * decremented again once the wait has ended, so it reflects the number of producers that are
     * currently blocked. Previously it was incremented a second time instead of decremented, so
     * it never returned to zero.
     *
     * @param t the task to store
     * @throws NullPointerException if {@code t} is null
     */
    private void enqueue(T t) {
        if (t == null) {
            throw new NullPointerException();
        }
        while (count.sum() >= capacity) {
            boolean waited = false;
            notFullLock.lock();
            try {
                // Re-check under the lock; the outer loop re-checks again after every wake-up.
                if (count.sum() >= capacity) {
                    waited = true;
                    fullLockThreadCount.increment();
                    notFull.awaitUninterruptibly();
                }
            } finally {
                notFullLock.unlock();
                if (waited) {
                    fullLockThreadCount.decrement();
                }
            }
        }

        // Claim a unique write position and publish the task there.
        long currentPutIndex = putIndex.incrementAndGet();
        Object[] currentPutQueue = getQueue((int) (currentPutIndex / SINGLE_QUEUE_CAPACITY));
        currentPutQueue[(int) (currentPutIndex % SINGLE_QUEUE_CAPACITY)] = t;
        count.increment();

        // Wake the core worker if it is waiting and nobody has signalled it yet.
        if (emptyWait && !notEmptySignal && signalAccess.compareAndSet(true, false)) {
            notEmptyLock.lock();
            try {
                notEmpty.signal();
                notEmptySignal = true;
            } finally {
                notEmptyLock.unlock();
            }
            signalAccess.set(true);
        }

        // Add non-core workers when too many tasks are pending.
        checkNonCoreWorkerCount(250);
    }

    /**
     * Takes the next task from the queue.
     * <p>
     * The core worker first retries take indexes that were parked in the took-index caches;
     * otherwise a new take index is claimed. If the slot at the claimed index still holds the
     * placeholder (its producer has not written the task yet), the index is parked for a later
     * retry and null is returned.
     *
     * @param isCore whether the calling worker is the core worker
     * @return the task, or null when no task could be taken in this call
     */
    @SuppressWarnings("unchecked")
    T dequeue(boolean isCore) {
        // Check whether too many tasks are pending and, if so, add non-core workers.
        checkNonCoreWorkerCount(1000);

        // Try to remove segments that are no longer needed.
        accessRemoveQueue();

        long currentTakeIndex = isCore ? getTakeIndexByCache() : -1;
        if (currentTakeIndex == -1) {
            // Nothing parked: claim a fresh index unless the queue is empty or fully consumed.
            if (count.sum() == 0 || takeIndex.get() >= putIndex.get()) {
                return null;
            }
            currentTakeIndex = takeIndex.incrementAndGet();
        }
        int queueIndex = (int) (currentTakeIndex / SINGLE_QUEUE_CAPACITY);
        int singleQueueIndex = (int) (currentTakeIndex % SINGLE_QUEUE_CAPACITY);
        Object[] currentTakeQueue = getQueue(queueIndex);
        Object obj = currentTakeQueue[singleQueueIndex];
        if (obj == FILL_OBJECT) {
            // The slot has been claimed but not written yet; park the index for a later retry.
            saveTookIndex(currentTakeIndex);
            return null;
        }
        if (obj == null) {
            return null;
        }

        // A task was found: clear its slot and decrement the pending count.
        currentTakeQueue[singleQueueIndex] = null;
        count.decrement();

        // If producers are blocked because the queue was full, wake them up (core worker only).
        if (isCore && fullLockThreadCount.sum() > 0) {
            notFullLock.lock();
            try {
                notFull.signalAll();
            } finally {
                notFullLock.unlock();
            }
        }

        // On reaching the first slot of a new segment, try to remove the previous segment.
        if (singleQueueIndex == 0 && queueIndex > 0) {
            removeQueue(queueIndex - 1, true);
        }
        return (T) obj;
    }

    /**
     * Checks whether too many tasks are pending and, if so, starts additional non-core workers
     * (one per {@code singleWorkerThreshold} pending tasks, up to {@code maxWorkerCount}).
     * <p>
     * Only one thread at a time may run the check ({@code nonCoreWorkerCreateAccess}). The guard is
     * released in a {@code finally} block so that a failed submission (for example a rejected
     * execution) does not disable worker creation permanently. A failed submission also rolls back
     * the worker count and the borrowed worker index.
     *
     * @param minIntervalTime minimum time in milliseconds since the last check
     */
    private void checkNonCoreWorkerCount(long minIntervalTime) {
        long currTime = System.currentTimeMillis();
        if (currTime - lastNonCoreWorkerCountCheckTime > minIntervalTime && nonCoreWorkerCreateAccess.compareAndSet(true, false)) {
            try {
                // Multiply as long: workerCount * singleWorkerThreshold may exceed the int range.
                while (workerCount.get() < maxWorkerCount && count.sum() > (long) workerCount.get() * singleWorkerThreshold) {
                    workerCount.getAndIncrement();
                    Worker worker = new Worker(false);
                    try {
                        executor.execute(worker);
                    } catch (RuntimeException | Error e) {
                        // The worker will never run, so undo what was reserved for it.
                        worker.giveBackWorkerIndex(worker.workerIndex);
                        workerCount.decrementAndGet();
                        throw e;
                    }
                }
                lastNonCoreWorkerCountCheckTime = currTime;
            } finally {
                nonCoreWorkerCreateAccess.set(true);
            }
        }
    }

    /**
     * Retries the removal of segments recorded in {@code needRemoveQueueIndexCache}.
     * <p>
     * Only one thread at a time may run the sweep ({@code removeQueueAccess}). The guard is
     * released in a {@code finally} block so that an exception during the sweep does not block all
     * later sweeps.
     */
    private void accessRemoveQueue() {
        if (removeQueueAccess.compareAndSet(true, false)) {
            try {
                for (int i = 0; i < NEED_REMOVE_QUEUE_INDEX_CACHE_SIZE; i++) {
                    int pendingQueueIndex = this.needRemoveQueueIndexCache[i];
                    if (pendingQueueIndex != -1 && removeQueue(pendingQueueIndex, false)) {
                        // The segment is gone; free its cache slot.
                        deleteRemoveQueueIndex(i);
                    }
                }
            } finally {
                removeQueueAccess.set(true);
            }
        }
    }

    /**
     * Fetches a previously parked take index: the lock-free cache is scanned first, then the
     * synchronized overflow cache.
     *
     * @return a parked take index (removed from its cache), or -1 when both caches are empty
     */
    private long getTakeIndexByCache() {
        for (int slot = 0; slot < TOOK_INDEX_CACHE_SIZE; slot++) {
            long cached = this.tookIndexCache[slot];
            if (cached == -1) {
                continue;
            }
            this.tookIndexCache[slot] = -1;
            return cached;
        }
        if (this.tookIndexWithLockCache.isEmpty()) {
            return -1;
        }
        // Iterating a synchronized collection requires holding its monitor.
        synchronized (this.tookIndexWithLockCache) {
            Iterator<Long> parked = this.tookIndexWithLockCache.iterator();
            if (!parked.hasNext()) {
                return -1;
            }
            long first = parked.next();
            parked.remove();
            return first;
        }
    }

    /**
     * Returns the segment with the given index, creating it when it does not exist yet.
     *
     * @param queueIndex index of the segment
     * @return the existing or newly created segment
     */
    private Object[] getQueue(int queueIndex) {
        Object[] existing = this.queue.get(queueIndex);
        // Plain get() first; computeIfAbsent is only needed when the segment is missing.
        return existing != null ? existing : this.queue.computeIfAbsent(queueIndex, key -> newArray());
    }

    /**
     * Removes a segment from the queue map when every slot of it has been consumed (is null).
     * <p>
     * A segment that is not present in the map is reported as removed instead of failing with a
     * NullPointerException. This lets {@code accessRemoveQueue} clear the corresponding cache slot.
     *
     * @param queueIndex index of the segment to remove
     * @param isSave     whether to record the index for a later retry when the segment is still in use
     * @return true when the segment is no longer in the map, false when it still holds data
     */
    private boolean removeQueue(int queueIndex, boolean isSave) {
        Object[] lastQueue = this.queue.get(queueIndex);
        if (lastQueue == null) {
            return true;
        }
        for (Object slot : lastQueue) {
            if (slot != null) {
                // Still holds a task or an unwritten placeholder; optionally retry later.
                if (isSave) {
                    saveNeedRemoveQueueIndex(queueIndex);
                }
                return false;
            }
        }
        this.queue.remove(queueIndex);
        return true;
    }

    /**
     * Creates a new segment whose slots are all pre-filled with the placeholder object.
     *
     * @return a fresh copy of the template segment
     */
    private Object[] newArray() {
        return BASE_QUEUE.clone();
    }

    /**
     * Parks a take index for a later retry. The lock-free cache is preferred; when the chosen
     * slot is occupied, the index goes to the synchronized overflow cache instead.
     *
     * @param tookIndex the take index to park
     */
    private void saveTookIndex(long tookIndex) {
        long sequence = tookIndexCacheIndex.getAndIncrement();
        int slot = (int) (sequence % TOOK_INDEX_CACHE_SIZE);
        boolean slotOccupied = this.tookIndexCache[slot] != -1;
        if (slotOccupied) {
            tookIndexWithLockCache.add(tookIndex);
            return;
        }
        this.tookIndexCache[slot] = tookIndex;
    }

    /**
     * Records the index of a segment that still has to be removed, using the first free slot of
     * the cache. Nothing is recorded when the cache has no free slot.
     *
     * @param queueIndex index of the segment to remove later
     */
    private void saveNeedRemoveQueueIndex(int queueIndex) {
        mainLock.lock();
        try {
            int slot = 0;
            // Find the first empty slot (marked with -1).
            while (slot < NEED_REMOVE_QUEUE_INDEX_CACHE_SIZE && this.needRemoveQueueIndexCache[slot] != -1) {
                slot++;
            }
            if (slot < NEED_REMOVE_QUEUE_INDEX_CACHE_SIZE) {
                this.needRemoveQueueIndexCache[slot] = queueIndex;
            }
        } finally {
            mainLock.unlock();
        }
    }

    /**
     * Frees a slot of the need-remove cache by resetting it to -1.
     *
     * @param index position of the slot inside {@code needRemoveQueueIndexCache}
     */
    private void deleteRemoveQueueIndex(int index) {
        final Lock guard = mainLock;
        guard.lock();
        try {
            this.needRemoveQueueIndexCache[index] = -1;
        } finally {
            guard.unlock();
        }
    }

    /**
     * Returns the effective number of retries after a failure (0 until the first add() call).
     */
    public int getFailRetryCount() {
        return failRetryCount;
    }

    /**
     * Returns the current value of the failed-and-discarded task counter.
     */
    public long getFailedDiscardCount() {
        return failedDiscardCount.sum();
    }

    /**
     * Returns whether the lazy initialisation has completed.
     */
    public boolean isInit() {
        return inited;
    }

    /**
     * Returns whether this is a batch task.
     */
    public boolean isBatch() {
        return batch;
    }

    /**
     * Returns the maximum number of tasks handled per execution.
     */
    public int getMaxExecuteCountPerTime() {
        return maxExecuteCountPerTime;
    }

    /**
     * Returns the maximum task capacity.
     */
    public int getCapacity() {
        return capacity;
    }

    /**
     * Returns the number of tasks currently in the queue.
     */
    public long getQueueCount() {
        return count.sum();
    }

    /**
     * Returns the total number of tasks submitted so far (last claimed put index + 1).
     */
    public long getTotalCount() {
        return putIndex.get() + 1;
    }

    /**
     * Returns the number of completed tasks, computed as the last claimed take index + 1.
     * NOTE(review): this counts claimed take indexes, which may include indexes parked for a retry.
     */
    public long getCompletedCount() {
        return takeIndex.get() + 1;
    }

    /**
     * Returns the current number of workers.
     * <p>
     * The counter is created lazily by the first {@code add} call; before that no worker exists
     * and 0 is returned instead of failing with a NullPointerException.
     */
    public int getWorkerCount() {
        AtomicInteger current = workerCount;
        return current == null ? 0 : current.get();
    }

    /**
     * Returns the maximum number of workers allowed.
     */
    public int getMaxWorkerCount() {
        return maxWorkerCount;
    }

    /**
     * Returns the current value of the counter of producers blocked because the queue was full.
     */
    public long getFullLockThreadCount() {
        return fullLockThreadCount.sum();
    }

    /**
     * Builds an overview of the pending error data: one entry per failure count, in ascending
     * order, mapping a label to the size of the corresponding error list.
     *
     * @return the overview, or an empty map when no error data is pending
     */
    @NonNull
    public Map<String, Integer> getErrorDataCountMap() {
        Map<String, Integer> countsByFailTimes = new LinkedHashMap<>();
        int total = 0;
        int failTimes = 0;
        while (failTimes < failRetryCount) {
            int size = this.errorDataListArray[failTimes].size();
            // Labels are 1-based: the list at index 0 holds tasks that failed once.
            failTimes++;
            countsByFailTimes.put("失败" + failTimes + "次", size);
            total += size;
        }
        if (total > 0) {
            return countsByFailTimes;
        }
        return Collections.emptyMap();
    }

    /**
     * Registers the new instance with AsynchronousTaskHolder (for monitoring, per the original
     * author's comment) as soon as it is created.
     * NOTE(review): {@code this} is handed out before subclass constructors have run — confirm the
     * holder does not use the instance immediately.
     */
    protected AbstractAsynchronousTask() {
        AsynchronousTaskHolder.addAsynchronousTask(this);
    }

}